package cn.doitedu.ml.gender

import cn.doitedu.commons.util.SparkUtil
import cn.doitedu.ml.util.VecUtil.arr2Vec
import org.apache.log4j.{Level, Logger}
import org.apache.spark.ml.classification.NaiveBayes
import org.apache.spark.mllib.classification.SVMWithSGD
import org.apache.spark.mllib.linalg
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row

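/**
 * @date: 2020/2/23
 * @site: www.doitedu.cn
 * @author: hunter.d 涛哥
 * @qq: 657270652
 * @description: Gender profile-tag prediction from user behavior.
 *        Algorithms used: naive Bayes classification (spark.ml) and,
 *        for comparison, a linear SVM (spark.mllib SVMWithSGD).
 */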
object GenderPredict {

  /**
   * Builds the feature-extraction query for a registered temp view.
   *
   * The query keeps `label` (cast to double so that downstream `Row` pattern
   * matches do not depend on what CSV schema inference picked) and packs the
   * feature columns into a single vector column named `features`:
   *  - category1..3 and brand1..3 are passed through unchanged;
   *  - day30_buy_cnts is bucketed into 0.0..4.0 (boundaries 10/20/30/40);
   *  - day30_buy_amt is bucketed into 0.0..6.0
   *    (boundaries 100/200/500/1000/2000/5000).
   *
   * Training and test data must be discretised identically, so both go
   * through this single definition.
   *
   * @param table name of the temp view to select from (an internal constant,
   *              never user input)
   * @return the SQL text; requires the `arr2Vec` UDF to be registered
   */
  private def featureSql(table: String): String =
    s"""
       |select
       |cast(label as double) as label,
       |arr2vec(
       |array(
       |category1,
       |category2,
       |category3,
       |brand1,
       |brand2,
       |brand3,
       |case  when day30_buy_cnts<10 then 0.0
       |      when day30_buy_cnts>=10 and day30_buy_cnts<20 then 1.0
       |      when day30_buy_cnts>=20 and day30_buy_cnts<30 then 2.0
       |      when day30_buy_cnts>=30 and day30_buy_cnts<40 then 3.0
       |      else 4.0
       |end,
       |case  when day30_buy_amt<100 then 0.0
       |      when day30_buy_amt>=100 and day30_buy_amt<200 then 1.0
       |      when day30_buy_amt>=200 and day30_buy_amt<500 then 2.0
       |      when day30_buy_amt>=500 and day30_buy_amt<1000 then 3.0
       |      when day30_buy_amt>=1000 and day30_buy_amt<2000 then 4.0
       |      when day30_buy_amt>=2000 and day30_buy_amt<5000 then 5.0
       |      else 6.0
       |end
       |)
       |) as features
       |
       |from $table
       |
      """.stripMargin

  /** Copies a spark.ml vector into a dense spark.mllib vector (SVMWithSGD uses the mllib API). */
  private def toMllibVector(v: org.apache.spark.ml.linalg.Vector): linalg.Vector =
    Vectors.dense(v.toArray)

  /**
   * Trains a naive Bayes model and an SVM on the sample data, then prints
   * the predictions of both models for the test data.
   *
   * @param args command-line arguments (unused)
   */
  def main(args: Array[String]): Unit = {
    Logger.getLogger("org.apache.spark").setLevel(Level.WARN)
    val spark = SparkUtil.getSparkSession(this.getClass.getSimpleName)

    // Release the session even when loading, training or prediction fails.
    try {
      // Load the training samples.
      val sample = spark.read.option("header","true").option("inferSchema",true).csv("userprofile/data/gender/sample")
      sample.show(100,false)

      // Feature engineering.
      // Feature columns: category1,category2,category3,brand1,brand2,brand3,day30_buy_cnts,day30_buy_amt
      // day30_buy_cnts and day30_buy_amt are discretised and everything is packed into a vector.
      // createOrReplaceTempView: do not fail if the view is left over from a previous run in a reused session.
      sample.createOrReplaceTempView("sample")
      spark.udf.register("arr2Vec",arr2Vec)
      val features = spark.sql(featureSql("sample"))

      features.show(100,false)

      // Configure the naive Bayes estimator.
      val bayes = new NaiveBayes()
        .setFeaturesCol("features")
        .setLabelCol("label")
        .setSmoothing(1.0)
      // Train the naive Bayes model.
      val bayes_model = bayes.fit(features)

      // Prepare the SVM training set: mllib expects RDD[LabeledPoint] with mllib vectors.
      // NOTE(review): assumes arr2Vec yields an org.apache.spark.ml.linalg.Vector, as the original match did.
      // Cached because SGD makes one pass over the data per iteration.
      val labledPointFeatures = features.rdd.map({
        case Row(label: Double, vec: org.apache.spark.ml.linalg.Vector) =>
          LabeledPoint(label, toMllibVector(vec))
      }).cache()
      // Train the SVM model (20 iterations).
      val svm_model = SVMWithSGD.train(labledPointFeatures,20)

      // Load the test data and apply the same feature engineering.
      val test = spark.read.option("header","true").option("inferSchema",true).csv("userprofile/data/gender/test")
      test.show(100,false)
      test.createOrReplaceTempView("test")
      val test_features = spark.sql(featureSql("test"))

      // Predict gender on the test data with the naive Bayes model.
      val predict1 = bayes_model.transform(test_features)
      predict1.show(100,false)

      // Convert the test features to mllib vectors and predict with the SVM.
      // The label is not needed for prediction, so it is not inspected here.
      val test_vec: RDD[linalg.Vector] = test_features.rdd.map({
        case Row(_, vec: org.apache.spark.ml.linalg.Vector) => toMllibVector(vec)
      })
      val predict2: RDD[Double] = svm_model.predict(test_vec)
      predict2.take(100).foreach(println)
    } finally {
      spark.close()
    }

  }

}
